# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""R(2+1)D network."""

import math
from mindspore import nn
from mindspore import ops as P
from mindspore.common import initializer as init

from mindvision.engine.class_factory import ClassFactory, ModuleType
from mindvision.video.ops.adaptiveavgpool3d import AdaptiveAvgPool3D

# Names exported by `from <this module> import *`: the building blocks, the
# generic network, and the two ready-made backbone configurations.
__all__ = [
    'BasicBlock',
    'Conv2Plus1D',
    'R2Plus1dStem',
    'R2Plus1dNet',
    'R2Plus1d18',  # registered with ClassFactory as a BACKBONE (usable from yaml configuration)
    'R2Plus1d50',  # registered with ClassFactory as a BACKBONE (usable from yaml configuration)
]


class Conv2Plus1D(nn.SequentialCell):
    """Factorised (2+1)D convolution used inside the R(2+1)D stages.

    The cell chains four sub-cells, in this order:

    1. ``nn.Conv3d`` with kernel ``(1, 3, 3)`` and stride ``(1, stride, stride)``;
    2. ``nn.BatchNorm3d`` over ``midplanes`` channels;
    3. ``nn.ReLU``;
    4. ``nn.Conv3d`` with kernel ``(3, 1, 1)`` and stride ``(stride, 1, 1)``.

    Both convolutions use ``pad_mode='pad'`` and carry no bias.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        midplanes (int): Number of channels between the two convolutions.
        stride (int): Stride value placed in the strided positions of both
            convolutions. Default: 1.
        padding (int): Padding value used on the axes covered by the size-3
            kernel dimensions. Default: 1.
    """

    def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1):
        # Sub-cells are created in the same order in which they are chained.
        conv_1x3x3 = nn.Conv3d(
            in_planes,
            midplanes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            pad_mode='pad',
            padding=(0, 0, padding, padding, padding, padding),
            has_bias=False,
        )
        mid_norm = nn.BatchNorm3d(midplanes)
        mid_act = nn.ReLU()
        conv_3x1x1 = nn.Conv3d(
            midplanes,
            out_planes,
            kernel_size=(3, 1, 1),
            stride=(stride, 1, 1),
            pad_mode='pad',
            padding=(padding, padding, 0, 0, 0, 0),
            has_bias=False,
        )
        super().__init__([conv_1x3x3, mid_norm, mid_act, conv_3x1x1])

    @staticmethod
    def get_downsample_stride(stride):
        """Return `stride` repeated three times as a tuple, for the shortcut convolution."""
        return (stride,) * 3


class BasicBlock(nn.Cell):
    """Two-convolution residual block for the R(2+1)D stages.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convolutions.
        conv_builder (callable): Called as
            ``conv_builder(in_channels, out_channels, midplanes[, stride])``
            to create each convolution cell.
        stride (int): Stride handed to the first convolution only. Default: 1.
        downsample (nn.Cell, optional): Cell applied to the input before it is
            added to the convolution output. When None the input is added
            unchanged. Default: None.
    """
    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        super().__init__()

        # Intermediate width shared by both factorised convolutions.
        numerator = inplanes * planes * 3 * 3 * 3
        denominator = inplanes * 3 * 3 + 3 * planes
        midplanes = numerator // denominator

        # First convolution: carries the stride and is followed by BN + ReLU.
        self.conv1 = nn.SequentialCell(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(),
        )
        # Second convolution: default stride, BN only (ReLU comes after the add).
        self.conv2 = nn.SequentialCell(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes),
        )
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        """Run both convolutions, add the shortcut branch and apply ReLU."""
        out = self.conv2(self.conv1(x))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(out + shortcut)


class R2Plus1dStem(nn.SequentialCell):
    """
    Stem of the R(2+1)D network, built from two separated 3D convolutions.

    Chained sub-cells, in order: Conv3d(3 -> 45, kernel (1, 7, 7), stride (1, 2, 2)),
    BatchNorm3d(45), ReLU, Conv3d(45 -> 64, kernel (3, 1, 1), stride (1, 1, 1)),
    BatchNorm3d(64), ReLU. Neither convolution has a bias.
    """

    def __init__(self):
        # Sub-cells are created in the same order in which they are chained.
        conv_1x7x7 = nn.Conv3d(
            3,
            45,
            kernel_size=(1, 7, 7),
            stride=(1, 2, 2),
            pad_mode='pad',
            padding=(0, 0, 3, 3, 3, 3),
            has_bias=False,
        )
        norm_mid = nn.BatchNorm3d(45)
        act_mid = nn.ReLU()
        conv_3x1x1 = nn.Conv3d(
            45,
            64,
            kernel_size=(3, 1, 1),
            stride=(1, 1, 1),
            pad_mode='pad',
            padding=(1, 1, 0, 0, 0, 0),
            has_bias=False,
        )
        norm_out = nn.BatchNorm3d(64)
        act_out = nn.ReLU()
        super().__init__(
            [conv_1x7x7, norm_mid, act_mid, conv_3x1x1, norm_out, act_out])


class R2Plus1dNet(nn.Cell):
    """Generic R(2+1)D network: stem, four residual stages, global pooling and a dense head.

    Args:
        block (type): Residual block class. It must expose an ``expansion``
            class attribute and be callable as
            ``block(inplanes, planes, conv_builder[, stride, downsample])``.
        conv_makers (list): Convolution builders, one per stage; the first four
            entries are used. Each builder must provide
            ``get_downsample_stride(stride)``.
        layers (list[int]): Number of blocks in each of the four stages.
        stem (callable): Called with no arguments to create the stem cell
            (typically a cell class such as ``R2Plus1dStem``). The first stage
            is built for 64 input channels, so the stem is expected to output
            64 channels.
        num_classes (int, optional): Output size of the final dense layer.
            Defaults to 400.

    Returns:
        Tensor, output of the dense layer, one row of ``num_classes`` values
        per input sample.

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindvision.video.models.backbones.r2plus1d import *
        >>> data = Tensor(np.random.randn(2, 3, 16, 112, 112), dtype=mindspore.float32)
        >>> net = R2Plus1dNet(block=BasicBlock,
        ...                   conv_makers=[Conv2Plus1D] * 4,
        ...                   layers=[3, 4, 6, 3],
        ...                   stem=R2Plus1dStem)
        >>> predict = net(data)
        >>> print(predict.shape)
        (2, 400)
    """

    def __init__(self, block, conv_makers, layers,
                 stem, num_classes=400):

        super(R2Plus1dNet, self).__init__()
        # Channel count fed into the next stage; `_make_layer` updates it.
        self.inplanes = 64

        self.stem = stem()

        # Four stages; all but the first use stride 2.
        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = AdaptiveAvgPool3D((1, 1, 1))
        self.fc = nn.Dense(512 * block.expansion, num_classes)

        self.flatten = P.Flatten()
        # init weights
        self._initialize_weights()

    def construct(self, x):
        """Run stem, the four stages, pooling, flatten and the dense head."""
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the pooled features before the dense layer.
        x = self.flatten(x)
        x = self.fc(x)

        return x

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
        """Build one stage made of `blocks` residual blocks.

        Args:
            block (type): Residual block class.
            conv_builder (callable): Convolution builder handed to every block.
            planes (int): Base channel count of the stage.
            blocks (int): Number of blocks in the stage.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.SequentialCell, the blocks of the stage chained in order.
        """
        downsample = None
        out_planes = planes * block.expansion

        # A 1x1x1 conv + BN shortcut is needed whenever the first block changes
        # the stride or the channel count.
        if stride != 1 or self.inplanes != out_planes:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.SequentialCell([
                nn.Conv3d(self.inplanes, out_planes,
                          kernel_size=1, stride=ds_stride, has_bias=False),
                nn.BatchNorm3d(out_planes)
            ])

        # Only the first block receives the stride and the shortcut cell.
        layers = [block(self.inplanes, planes, conv_builder, stride, downsample)]

        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        # Pass the list itself rather than unpacking it, so that a stage with a
        # single block is handled the same way as a multi-block stage.
        return nn.SequentialCell(layers)

    def _initialize_weights(self):
        """
        Init the weights of the Conv3d, BatchNorm2d and Dense cells in the net.

        Conv3d weights use HeNormal (fan_out, relu), Dense weights use
        Normal(0.01), biases are zeroed, and batch-norm gamma/beta are set to
        one/zero.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv3d):
                cell.weight.set_data(init.initializer(
                    init.HeNormal(math.sqrt(5), mode='fan_out', nonlinearity='relu'),
                    cell.weight.shape, cell.weight.dtype))
                # Compare against None: the truth value of a multi-element
                # parameter is not a presence check.
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(
                        init.Zero(), cell.bias.shape, cell.bias.dtype))
            # NOTE(review): this file only instantiates nn.BatchNorm3d; this
            # branch presumably matches the BatchNorm2d cells that MindSpore's
            # BatchNorm3d wraps internally - confirm against the MindSpore
            # version in use.
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer(
                    init.One(), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(
                    init.Zero(), cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Dense):
                cell.weight.set_data(init.initializer(
                    init.Normal(0.01), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(
                        init.Zero(), cell.bias.shape, cell.bias.dtype))


@ClassFactory.register(ModuleType.BACKBONE)
class R2Plus1d18(R2Plus1dNet):
    """
    R(2+1)D-18 backbone: `BasicBlock` with two blocks in each of the four
    stages, `Conv2Plus1D` convolutions and the `R2Plus1dStem` stem.

    The class is registered with `ClassFactory` as a BACKBONE so that it can be
    created from a yaml configuration file.

    Args:
        **kwargs: Extra keyword arguments forwarded to `R2Plus1dNet`
            (for example ``num_classes``).
    """

    def __init__(self, **kwargs):
        super().__init__(
            block=BasicBlock,
            conv_makers=[Conv2Plus1D for _ in range(4)],
            layers=[2] * 4,
            stem=R2Plus1dStem,
            **kwargs,
        )


@ClassFactory.register(ModuleType.BACKBONE)
class R2Plus1d50(R2Plus1dNet):
    """
    R(2+1)D-50 backbone as configured in this file: `BasicBlock` with
    3, 4, 6 and 3 blocks in the four stages, `Conv2Plus1D` convolutions and the
    `R2Plus1dStem` stem.

    The class is registered with `ClassFactory` as a BACKBONE so that it can be
    created from a yaml configuration file.

    Args:
        **kwargs: Extra keyword arguments forwarded to `R2Plus1dNet`
            (for example ``num_classes``).
    """

    def __init__(self, **kwargs):
        stage_depths = [3, 4, 6, 3]
        super().__init__(
            block=BasicBlock,
            conv_makers=[Conv2Plus1D for _ in range(4)],
            layers=stage_depths,
            stem=R2Plus1dStem,
            **kwargs,
        )
